Skip to content

[Kernel] Enable fp8 support for pplx and BatchedTritonExperts. #18864

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 77 commits into from
Jul 3, 2025

Conversation

bnellnm
Copy link
Contributor

@bnellnm bnellnm commented May 28, 2025

Enable full fp8 support for pplx and BatchedTritonExperts.

  • Replace world_size/dp_size arguments to PrepareAndFinalize and Experts constructors with num_dispatchers.
  • Reduce use of duplicate information for setup, i.e. try to get all the parameters from the FusedMoEConfig rather than all2all_manager or random variables.
  • Rewrote the pplx tests so that they run in a loop on the spawned process rather than spawning a process for each test point. The original slow test points can still be run with the --optional pytest flag.
  • Add a bunch more quantization tests to cover all the combinations of per-token, per-tensor and blocked.

I've verified all the combinations from here work properly: dispatch_combine fp8 support matrix by branch + model.xlsx
with DP=2/TP=1, DP=2/TP=2 and DP=4/TP=1.

lm-eval results for RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic with pplx, DP=4, TP=1.

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.86|±  |0.0349|
|     |       |strict-match    |     5|exact_match|↑  | 0.81|±  |0.0394|

cc @ElizaWszola

Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label May 28, 2025
Copy link

mergify bot commented May 28, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @bnellnm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Copy link

mergify bot commented Jun 3, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @bnellnm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Copy link

mergify bot commented Jun 13, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @bnellnm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added needs-rebase qwen Related to Qwen models labels Jun 13, 2025
@bnellnm bnellnm force-pushed the batch-fp8 branch 2 times, most recently from 911339b to f92734e Compare June 24, 2025 21:09
@bnellnm bnellnm mentioned this pull request Jun 25, 2025
@bnellnm bnellnm marked this pull request as ready for review June 26, 2025 21:20
@mergify mergify bot removed the needs-rebase label Jun 26, 2025
@bnellnm bnellnm changed the title [Kernel] Fix fp8 support for pplx and BatchedTritonExperts. [Kernel] Enable fp8 support for pplx and BatchedTritonExperts. Jun 26, 2025
Copy link

mergify bot commented Jun 26, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @bnellnm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

bnellnm added 3 commits July 2, 2025 16:04
Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: Bill Nell <bnell@redhat.com>
@varun-sundar-rabindranath
Copy link
Contributor

LGTM! Really nice cleanups @bnellnm 🙌

@bnellnm
Copy link
Contributor Author

bnellnm commented Jul 2, 2025

LGTM! Really nice cleanups @bnellnm 🙌

Thanks!

Signed-off-by: Bill Nell <bnell@redhat.com>
bnellnm added 3 commits July 2, 2025 23:24
Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: Bill Nell <bnell@redhat.com>
@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) July 3, 2025 02:52
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 3, 2025
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("input_scales", [False])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this only False?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've left it here for future testing,

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Should there be also a condition in the test code to skip the test if input_scales == True and quant_dtype is None?

Copy link
Contributor Author

@bnellnm bnellnm Jul 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's one of the conditions that needs more testing. There's some int8/int4 quantization schemes that happen outside the triton kernels. So they need to pass in the quantized data + scales, but no quant_type since they are already quantized.

@@ -178,6 +175,8 @@ def run_cutlass_moe_fp8(
c2 = _resize_cache(workspace2, (M * topk, N))
c3 = _resize_cache(workspace13, (M * topk, K))

c1.fill_(0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we have a condition here that we only zero-out c1 if expert_map is not none and per_act_token == True? As far as I'm aware, this is the only case when it's needed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's another PR that has the proper condition for this. I don't want to have to rerun everything at this point. I'll let that other PR push the better fix.

@simon-mo simon-mo disabled auto-merge July 3, 2025 21:55
@simon-mo simon-mo merged commit 78fe775 into vllm-project:main Jul 3, 2025
87 of 91 checks passed
sfeng33 pushed a commit to sfeng33/vllm that referenced this pull request Jul 6, 2025
huydhn pushed a commit to huydhn/vllm that referenced this pull request Jul 8, 2025
Chen-zexi pushed a commit to Chen-zexi/vllm that referenced this pull request Jul 13, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
performance Performance-related issues qwen Related to Qwen models ready ONLY add when PR is ready to merge/full CI is needed v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants